import os

from config import ModelLayers
import importlib


def layer_class_dict_factory(package_dir=None, package=None):
    """Discover layer classes defined in the modules of a package directory.

    Every ``*.py`` file in ``package_dir`` whose stem is a valid identifier and
    does not start with an underscore is imported as a submodule of
    ``package``. Inside each module, an attribute is picked up when its
    lower-cased name equals the module name with underscores removed
    (e.g. ``conv_layer.py`` -> ``ConvLayer``).

    Args:
        package_dir: Directory to scan. Defaults to the directory containing
            this file.
        package: Dotted package name the modules are imported relative to.
            Defaults to this module's ``__name__`` (i.e. this file is expected
            to be the package's ``__init__.py``).

    Returns:
        dict mapping each matched attribute's ``__name__`` to the attribute.
    """
    if package_dir is None:
        package_dir = os.path.dirname(__file__)
    if package is None:
        package = __name__

    layer_classes = {}
    # Sorted so the result does not depend on the filesystem's listing order
    # (matters when two modules export attributes with the same __name__).
    for file_name in sorted(os.listdir(package_dir)):
        mod_name, ext = os.path.splitext(file_name)
        # Only real Python source modules can be imported here. Blindly
        # chopping the last 3 characters turned READMEs, .pyc files, hidden
        # files and sub-directories into bogus module names and crashed.
        if ext != ".py" or not mod_name.isidentifier() or mod_name.startswith("_"):
            continue
        module = importlib.import_module("." + mod_name, package=package)
        wanted = mod_name.replace("_", "")
        for attr_name, attr in vars(module).items():
            if attr_name.lower() == wanted:
                layer_classes[attr.__name__] = attr
    return layer_classes


def _layer_dict_factory():
    """Map every public ``ModelLayers`` attribute value to its layer class.

    For each attribute ``X`` defined directly on ``ModelLayers`` whose name
    does not start with an underscore, the class named ``X + "Layer"`` is
    looked up among the classes returned by ``layer_class_dict_factory()``.

    Returns:
        dict mapping the value of ``ModelLayers.X`` to the class ``XLayer``.

    Raises:
        KeyError: if ``ModelLayers`` declares a name for which no matching
            ``<name>Layer`` class was discovered.
    """
    layer_classes = layer_class_dict_factory()
    ld = {}
    for name, value in ModelLayers.__dict__.items():
        if name.startswith("_"):
            continue
        class_name = name + "Layer"
        try:
            ld[value] = layer_classes[class_name]
        except KeyError:
            # Re-raise with context: a bare KeyError('FooLayer') does not say
            # which ModelLayers entry caused it or what was actually found.
            raise KeyError(
                "ModelLayers.{} has no matching layer class {!r}; "
                "discovered classes: {}".format(
                    name, class_name, sorted(layer_classes)
                )
            ) from None
    return ld


# Built once, when this module is imported: maps each public ModelLayers value
# to its "<Name>Layer" class. Because building it imports every layer module,
# an import error in any of them, or a ModelLayers entry without a matching
# class (KeyError), surfaces at import time of this module.
layer_dict = _layer_dict_factory()
